#!/usr/bin/env python
import getopt
import os
import re
import sys


# Module-level state filled in while parsing the environment and command line.
# is_debug / log_file: debug logging switch and the file debug lines go to.
# is_suspend / is_resume: which action keyword was given on the command line.
# djob_path: path of the 'djob' executable used for queries and actions.
# job_ids: job ids (digit strings) collected from the command line.
# job_err_info: formatted error lines printed at the end of the run.
is_debug = False
is_suspend = False
is_resume = False
log_file = ''
djob_path = ''
job_ids = []
job_err_info = []
# Column positions in one whitespace-split row of the custom djob query,
# and the number of columns such a row must have.
JOB_CUSTOM_ID_INDEX = 0
JOB_CUSTOM_TASK_INDEX_INDEX = 1
JOB_CUSTOM_TASK_STATE_INDEX = 2
JOB_CUSTOM_QUERY_LEN = 3
# Keys of the per-job dict that collects task indexes and task states.
CUSTOM_TASK_INDEX = 'task_index'
CUSTOM_TASK_STATE = 'task_state'

# Error row layout: first field right-aligned in 15 columns, then the message.
ERR_MSG_FORMAT = '{:>15}    {}'


def get_help_info():
    """Return the multi-line usage text of this scontrol wrapper.

    The text is assembled from individual lines; the trailing blanks on
    some lines are part of the message and are kept as they are.
    """
    usage_lines = (
        'scontrol [<OPTION>] [<COMMAND>]                                            ',
        '    Valid <OPTION> values are:                                             ',
        '     --help     show this help message                                             ',
        '',
        '    Valid <COMMAND> values are:',
        '     resume <jobid_list>      resume previously suspended job',
        '     suspend <jobid_list>     suspend specified job ',
    )
    return '\n'.join(usage_lines)


def log(content):
    """Append ``content`` as one line to the debug log file.

    Nothing is written unless the module-level ``is_debug`` flag is set.
    The target is the module-level ``log_file`` path, opened in append mode.

    :param content: text to log; a newline is added after it.
    """
    if not is_debug:
        return
    # The context manager closes the handle even when the write fails,
    # which the previous open/write/close sequence did not guarantee.
    with open(log_file, 'a') as log_handle:
        log_handle.write(content + '\n')


def init_log_file_name():
    """Set the global ``log_file`` to ``/tmp/scontrol.<uid>.<pid>``.

    The uid and pid of the current process make the name unique per run.
    """
    global log_file
    name_parts = ('scontrol', str(os.getuid()), str(os.getpid()))
    log_file = os.path.join('/tmp/', '.'.join(name_parts))


def init_config_with_env():
    """Initialise options from the environment.

    When ``SLURM_TO_DONAU_DEBUG`` equals "true" (compared case-insensitively)
    the global ``is_debug`` flag is switched on and the log file name is
    initialised; any other value, or an unset variable, changes nothing.
    """
    global is_debug
    if os.getenv('SLURM_TO_DONAU_DEBUG', '').lower() != 'true':
        return
    is_debug = True
    init_log_file_name()


def get_invalid_option(err_message):
    """Extract the offending option from an "option X not recognized" message.

    :param err_message: error text, e.g. from ``getopt.GetoptError``.
    :return: the option text when the pattern matches exactly once,
        otherwise an empty string.
    """
    # A message without 'not recognized' simply yields no match here.
    matches = re.findall(r'option (.*) not recognized', err_message)
    return matches[0] if len(matches) == 1 else ''


def print_invalid_option(option):
    """Print an "invalid option" message followed by a hint to use --help.

    Options whose second character is '-' (long options) are echoed in
    full; for anything else the first character is dropped and the rest is
    shown after "--". An empty ``option`` prints nothing.

    :param option: the rejected option text, including its leading dash(es).
    """
    if not option:
        return
    looks_long = option[1:2] == '-'
    if looks_long:
        message = "scontrol: invalid option '" + option + "'"
    else:
        message = "scontrol: invalid option -- '" + option[1:] + "'"
    print(message)
    print('''Try "scontrol --help" for more information''')


def parse_job_ids(job_id_str):
    """Collect job ids from the command-line arguments into ``job_ids``.

    Each argument may hold several comma-separated ids. Every id must
    consist of digits only; on the first one that does not, a message
    naming the whole argument is printed and the process exits with
    status 1.

    :param job_id_str: list of argument strings following the keyword.
    """
    global job_ids
    for raw_arg in job_id_str:
        for candidate in raw_arg.split(','):
            if candidate.isdigit():
                job_ids.append(candidate)
                continue
            print('Invalid job id specified for job %s' % raw_arg)
            sys.exit(1)
    log('parse valid job ids : %s' % ' '.join(job_ids))


def parse_command_args(remain_args):
    """Interpret the non-option arguments as ``<keyword> <jobid_list>``.

    Recognised keywords are 'suspend' and 'resume'; the matching global
    flag is set and the remaining arguments are handed to
    ``parse_job_ids``. A missing keyword, an unknown keyword, or a keyword
    without job ids prints a message and exits with status 1.

    :param remain_args: arguments left over after option parsing.
    """
    global is_suspend
    global is_resume
    if not remain_args:
        print('scontrol: too few options for scontrol')
        print('Try "scontrol --help" for more information')
        sys.exit(1)

    keyword = remain_args[0]
    has_job_ids = len(remain_args) > 1
    if keyword == 'suspend':
        if not has_job_ids:
            print('too few arguments for keyword:suspend')
            sys.exit(1)
        is_suspend = True
    elif keyword == 'resume':
        if not has_job_ids:
            print('too few arguments for keyword:resume')
            sys.exit(1)
        is_resume = True
    else:
        print('invalid keyword: %s' % keyword)
        sys.exit(1)
    parse_job_ids(remain_args[1:])


def check_result_error(out_line):
    """Classify one line of the djob query output.

    Lines recognised as errors are appended to the global ``job_err_info``
    (formatted with ``ERR_MSG_FORMAT`` where a job id can be extracted, raw
    otherwise). Header, separator and blank lines are ignored silently.

    :param out_line: one line of query output.
    :return: True only when the line is a task row, i.e. it has exactly
        ``JOB_CUSTOM_QUERY_LEN`` fields and the first one is all digits;
        False for everything else.
    """
    global job_err_info
    fields = out_line.split()
    # Blank or whitespace-only lines carry no information; reporting them
    # would print empty rows in the error table.
    if not fields:
        return False

    # Permission problems and unknown jobs share the same reporting shape:
    # use the leading job id when there is one, else keep the raw line.
    if 'access to job' in out_line:
        reason = 'Access/permission denied'
    elif 'not exist' in out_line:
        reason = 'Unexpected message received for job %s' % fields[0]
    else:
        reason = None
    if reason is not None:
        if fields[0].isdigit():
            job_err_info.append(ERR_MSG_FORMAT.format(fields[0], reason))
        else:
            job_err_info.append(out_line)
        return False

    # Table separator line.
    if out_line.startswith('---'):
        return False
    # Anything that is not a three-column row is reported verbatim.
    if len(fields) != JOB_CUSTOM_QUERY_LEN:
        job_err_info.append(out_line)
        return False

    first = fields[0]
    if first == 'JOB_ID':
        # Column header row.
        return False
    if first.isdigit():
        return True
    # NOTE(review): because of the length check above this only formats
    # 'message' lines with exactly three tokens; longer ones are reported
    # raw. Confirm whether that is intended.
    if first == 'message':
        job_err_info.append(ERR_MSG_FORMAT.format('message', ' '.join(fields[1:])))
    return False


def convert_donau_state_to_slurm(donau_state):
    """Translate a Donau task state name into the Slurm state name.

    :param donau_state: state name as reported by djob.
    :return: the Slurm equivalent, or ``donau_state`` itself when there is
        no translation for it.
    """
    translation = (
        (('SUCCEEDED',), 'COMPLETED'),
        (('FAILED',), 'FAILED'),
        (('STOPPED', 'SSTOPPED'), 'SUSPENDED'),
        (('PENDING', 'WAITING'), 'PENDING'),
    )
    for donau_names, slurm_name in translation:
        if donau_state in donau_names:
            return slurm_name
    return donau_state


def check_task_state(id, state_list, task_id_list):
    """Check that every task of a job is in the state the action requires.

    Suspending requires all tasks to be 'RUNNING'; resuming requires all
    tasks to be 'STOPPED'. The first task in any other state is recorded in
    the global ``job_err_info`` (with its state translated to the Slurm
    name) and the check stops there.

    :param id: job id used in the error row.
    :param state_list: task states, paired by position with ``task_id_list``.
    :param task_id_list: task indexes of the job.
    :return: True when all tasks qualify; False when one does not, or when
        neither suspend nor resume was requested.
    """
    global job_err_info
    if is_suspend:
        required_state = 'RUNNING'
        failure_template = 'task id %s state %s not correct, suspend job failed'
    elif is_resume:
        required_state = 'STOPPED'
        failure_template = 'task id %s state %s not correct, resume job failed'
    else:
        return False

    # For a multi-task job a single task in the wrong state blocks the action.
    for task_state, task_id in zip(state_list, task_id_list):
        if task_state == required_state:
            continue
        slurm_state = convert_donau_state_to_slurm(task_state)
        job_err_info.append(
            ERR_MSG_FORMAT.format(id, failure_template % (task_id, slurm_state)))
        return False
    return True


def parse_job_id(job_info):
    """Extract the job id from an "access to job <id> is denied" message.

    :param job_info: text to search.
    :return: the job id as a digit string when the message occurs exactly
        once in ``job_info``, otherwise an empty string.
    """
    # \d+ matches the numeric id; the previous pattern '(d+)' matched the
    # literal letter 'd' and therefore never found a job id.
    found_ids = re.findall(r'access to job (\d+) is denied', job_info)
    if len(found_ids) == 1:
        return found_ids[0]
    return ''


def append_err_message(job_id, err_output):
    """Turn the output of a suspend/resume command into error rows.

    Each line of ``err_output`` is examined; lines containing
    'successfully', the 'JOB_ID' header and lines with fewer than two
    tokens are ignored. Everything recognised as an error is appended to
    the global ``job_err_info`` using ``ERR_MSG_FORMAT``.

    :param job_id: id of the job the command was issued for; used for rows
        that do not carry their own job id.
    :param err_output: complete output text of the action command.
    """
    global job_err_info
    for line in err_output.splitlines():
        if 'successfully' in line:
            continue
        tokens = line.split()
        if len(tokens) <= 1:
            continue
        if tokens[0] == 'JOB_ID':
            continue
        if 'access to job' in line:
            job_err_info.append(ERR_MSG_FORMAT.format(job_id, 'Access/permission denied'))
            # One row per output line: without this the same line was
            # reported a second time by the generic branches below.
            continue
        if len(tokens) > 2:
            row_job = tokens[0]
            row_task = tokens[1]
            if row_job.isdigit() and row_task.isdigit():
                # "<job> <task> <text...>" row: report per task.
                job_err_info.append(ERR_MSG_FORMAT.format(
                    row_job, 'task %s: %s' % (row_task, ' '.join(tokens[2:]))))
            elif tokens[0] == 'message':
                # "message <text...>" row: attribute it to the current job.
                job_err_info.append(ERR_MSG_FORMAT.format(job_id, ' '.join(tokens[1:])))


def job_control():
    """Query the requested jobs and suspend or resume the eligible ones.

    Steps:
      1. Run the djob custom query (job id, task index, task state) for
         all ids in ``job_ids``.
      2. Sort the output lines into task rows and errors via
         ``check_result_error`` and group the task rows by job id.
      3. For every job whose tasks pass ``check_task_state``, run the
         suspend (``-S``) or resume (``-R``) command on all its tasks and
         collect errors from its output.

    Returns early when the query prints nothing; exits with status 1 when
    the query prints a single line containing 'error'.
    """
    query_command = '%s -x -CO \"JOB_ID TASK_INDEX TASK_STATE\"' % djob_path
    query_command += ' ' + ' '.join(job_ids)
    log('get query_command: %s' % query_command)

    # Execute the query and split its output into lines.
    output = os.popen(query_command).read()
    log('get output: %s' % output)
    result_lines = output.splitlines()
    if not result_lines:
        return
    if len(result_lines) == 1 and 'error' in result_lines[0]:
        print(result_lines[0])
        sys.exit(1)

    # Group valid task rows by job id; the two lists of a job always grow
    # together, so they stay non-empty and of equal length.
    tasks_by_job = {}
    for line in result_lines:
        if not check_result_error(line):
            log('check djob result error: %s' % line)
            continue
        columns = line.split()
        job_key = columns[JOB_CUSTOM_ID_INDEX]
        if job_key not in tasks_by_job:
            tasks_by_job[job_key] = {CUSTOM_TASK_INDEX: [], CUSTOM_TASK_STATE: []}
        entry = tasks_by_job[job_key]
        entry[CUSTOM_TASK_INDEX].append(columns[JOB_CUSTOM_TASK_INDEX_INDEX])
        entry[CUSTOM_TASK_STATE].append(columns[JOB_CUSTOM_TASK_STATE_INDEX])

    # The action switch is the same for every job of this run.
    if is_suspend:
        action_switch = ' -S '
    elif is_resume:
        action_switch = ' -R '
    else:
        action_switch = None

    for job_key, entry in tasks_by_job.items():
        task_indexes = entry[CUSTOM_TASK_INDEX]
        task_states = entry[CUSTOM_TASK_STATE]
        if not check_task_state(job_key, task_states, task_indexes):
            continue
        if action_switch is None:
            continue
        # One "<job>.<task>" target per task, each preceded by a blank.
        targets = ''.join(' %s.%s' % (job_key, index) for index in task_indexes)
        action_command = djob_path + action_switch + targets
        log('action_command: %s' % action_command)
        # Execute the action and record whatever it reports as failed.
        action_output = os.popen(action_command).read()
        log('action result: %s' % action_output)
        append_err_message(job_key, action_output)


def check_donau_cli_env():
    """Verify the Donau CLI environment and locate the djob executable.

    Both ``CCSCHEDULER_CLI_HOME`` and ``CCS_CLI_HOME`` must be set, and
    ``<CCS_CLI_HOME>/bin/djob`` must exist; the global ``djob_path`` is set
    to that path. If either condition fails a hint is printed and the
    process exits with status 1.
    """
    global djob_path
    scheduler_home = os.getenv('CCSCHEDULER_CLI_HOME')
    cli_home = os.getenv('CCS_CLI_HOME')
    env_missing = scheduler_home is None or cli_home is None
    if env_missing:
        print("cli env home is not configured, "
              "source configure 'profile.env'(bash) or 'cshrc.env'(csh) first")
        sys.exit(1)

    djob_path = os.path.join(cli_home, 'bin', 'djob')
    if os.path.exists(djob_path):
        return
    print("djob path is not exist,"
          "source configure 'profile.env'(bash) or 'cshrc.env'(csh) first")
    sys.exit(1)


if __name__ == '__main__':
    # Script entry point: validate the environment, parse the command
    # line, run the requested job action and print collected errors.
    if os.getuid() == 0:
        print('Permission denied. The root user is not allowed to operate.')
        sys.exit(1)
    check_donau_cli_env()
    init_config_with_env()
    short_options = ''
    long_options = ['help', 'usage']
    # Only option parsing can raise GetoptError, so only it is guarded.
    try:
        opts, args = getopt.getopt(sys.argv[1:], short_options, long_options)
    except getopt.GetoptError as err:
        log('get opt error: %s' % str(err))
        invalid_option = get_invalid_option(str(err))
        if len(invalid_option) != 0:
            print_invalid_option(invalid_option)
        else:
            print(str(err))
        # Every option error is fatal. Previously only the "not recognized"
        # case exited; other errors fell through and ran the job query
        # without any job id.
        sys.exit(1)

    for opt_name, opt_value in opts:
        # --usage is accepted as an alias of --help.
        if opt_name == '--help' or opt_name == '--usage':
            print(get_help_info())
            sys.exit()
    # Interpret the remaining arguments as "<keyword> <jobid_list>".
    parse_command_args(args)

    job_control()
    # Print the error table only when something went wrong.
    if len(job_err_info) > 0:
        print('         JOB_ID    MESSAGE')
    for err_info in job_err_info:
        print(err_info)
